You have been provided actual data from the covid-19 pandemic, as reported by state and national governments. These data have already had spikes or errors in reporting corrected and redistributed. Attached are data from the US (excluding Washington state):
These are the only 4 inputs that IHME uses for our first-stage deaths model, which produces 14-day forecasts that are then used in an SEIR model.
library(here)
library(tidyverse)
setwd(here::here())
data = read.csv("covid_data_cases_deaths_hosp.csv")
dim(data) #7914 observations, 9 variables
length(unique(data$location_id)) #checking that there are 51 state IDs
length(unique(data$Province.State)) #there are 51 unique state names
length(unique(data$Date)) #there are 194 unique dates
min(data$Date) #43856 -- converts to January 28 2020
max(data$Date) #44049 -- converts to Aug 8 2020
length(unique(data$US.STATE)) #US.STATE doesn't appear to mean anything as it is just 1 data = data %>%
mutate(
newdate = as.Date(Date, origin="1900-01-01"),
cases = if_else(Confirmed < 0, 0, Confirmed),
deaths = if_else(Deaths < 0, 0, Deaths),
hospitalizations = if_else(Hospitalizations < 0,0, Hospitalizations),
cases100k = (cases/population)*100000,
deaths100k = (deaths/population)*100000,
hosp100k = (hospitalizations/population)*100000)This table is arranged by most number of cases per 100k to least
summarytable = data %>%
group_by(Province.State) %>%
summarise(median_cases = median(cases100k, na.rm=TRUE),
median_deaths = median(deaths100k, na.rm=TRUE),
median_hosp = median(hosp100k, na.rm=TRUE)) %>%
arrange(desc(median_cases))
summarytableWe expect there to be a pretty substantial correlation – the higher the case count, the higher the number of deaths in each state. We find that there is an overall correlation coefficient of R = 0.82 between cases per 100,000 and deaths per 100,000. We also provide a correlation between cases and deaths by state in table and figure form (though not all states have correlation coefficients due to incomplete data)
## Loading required package: viridisLite
ggplot(data=data) +
geom_point(aes(x=cases100k, y=deaths100k), size=0.5) +
xlab('Cases/100k population') +
ylab ('Deaths/100k population') +
facet_wrap(~Province.State)+
theme_bw()#correlation coefficient by state
groupcorr = data %>%
group_by(Province.State) %>%
summarise(correlation = cor(cases100k, deaths100k)) %>%
filter(!is.na(correlation))
groupcorrggplot(groupcorr, aes(x = reorder(Province.State, correlation), y = correlation)) +
geom_bar(stat = "identity", fill = "steelblue") +
coord_flip() +
labs(title = "Correlation between Cases and Deaths per 100k by State",
x = "State",
y = "Correlation Coefficient") +
theme_minimal()#overall correlation coefficient
cor(data$cases100k, data$deaths100k, use = "complete.obs",
method = c("pearson", "kendall", "spearman")) #0.82## [1] 0.8214892
For hospitalizations, we will subset the data to rows that only contain hospitalization data. This way, we will avoid plotting blank points.
We see that there is an even higher correlation between deaths and hospitalizations when broken down by state, which is expected. The overall coefficient between deaths and hospitalizations is 0.93.
hospdata = data %>%
filter(!is.na(hospitalizations))
ggplot(data=hospdata) +
geom_point(aes(x=hosp100k, y=deaths100k), size=0.5) +
xlab('Hospitalizations/100k population') +
ylab ('Deaths/100k population') +
facet_wrap(~Province.State)+
theme_bw()groupcorr2= hospdata %>%
group_by(Province.State) %>%
summarise(correlation = cor(hosp100k, deaths100k, use="complete.obs")) %>%
filter(!is.na(correlation))
groupcorr2ggplot(groupcorr2, aes(x = reorder(Province.State, correlation), y = correlation)) +
geom_bar(stat = "identity", fill = "lightgreen") +
coord_flip() +
labs(title = "Correlation between Hospitalizations and Deaths per 100k by State",
x = "State",
y = "Correlation coefficient") +
theme_minimal()cor(hospdata$hosp100k, hospdata$deaths100k, use = "complete.obs",
method = c("pearson", "kendall", "spearman")) #0.93## [1] 0.9253078
Here we plot a map showing how cases and deaths vary spatially. For this example, I have only selected the cumulative case and death count on the last recorded day (8/8/2020) since the map is 1 dimensional in time. I have categorized the cumulative case adn death counts into a 9 level bivariate category to illustrate the relationship between COVID-19 cases and deaths by state: low case/low death count, low case/med death count, low case/high death, med case/low death, and so on and so forth.
The map below shows how cases and deaths are correlated throughout the country on 8/8/2020, noting that some states in the South (Louisiana, North Carolina, South Carolina) and East (New York, Massachusetts) having a high case and death count.
#here we filter out the relevant variables to make the dataset more manageable
map_data = data %>%
select(Province.State, newdate, cases100k, deaths100k, hosp100k) %>%
filter(newdate == max(newdate))
#creates a dataset containing case and death count data
statemap_join = left_join(map_data, statemap, by="Province.State")
#bivariate map file
casedeathmap = bi_class(statemap_join, x = cases100k, y = deaths100k, style="quantile") %>%
slice(-1)
casedeathmap = st_as_sf(casedeathmap)
color_palette = RColorBrewer::brewer.pal(9, "YlOrRd")
m5 = tm_shape(casedeathmap, projection = 2163, unit = "mi") +
tm_polygons("bi_class", palette = RColorBrewer::brewer.pal(9, "YlOrRd"),
border.col = "white", lwd = 0.1, legend.show=FALSE) +
tm_shape(statemap) +
tm_borders(col="black", lwd=0.3, lty="solid")+
tm_text("Province.State", size=0.6, root=3, remove.overlap=TRUE)+
tm_layout(
outer.margins = 0,
asp = 0,
legend.width=1,
legend.height=0.5
)+
tm_add_legend(
type = "fill",
col = color_palette,
title = "COVID Cases and Deaths per 100k",
is.portrait = TRUE,
labels = c("Low Cases/Low Deaths", "Low Cases/Med Deaths", "Low Cases/High Deaths",
"Med Cases/Low Deaths", "Med Cases/Med Deaths", "Med Cases/High Deaths",
"High Cases/Low Deaths", "High Cases/Med Deaths", "High Cases/High Deaths")
)
m5#check if variance is greater than mean for cases
with(data, tapply(cases, Province.State, function(x) {
sprintf("M (var) = %1.2f (%1.2f)", mean(x,na.rm=TRUE), var(x, na.rm=TRUE))
}))## Alabama
## "M (var) = 133.73 (1384.49)" "M (var) = 25704.53 (742795215.61)"
## Alaska Arizona
## "M (var) = 774.62 (712484.96)" "M (var) = 36259.52 (2990380977.45)"
## Arkansas California
## "M (var) = 12343.89 (187082538.76)" "M (var) = 117010.72 (22632032115.99)"
## Colorado Connecticut
## "M (var) = 21612.76 (218786478.31)" "M (var) = 30973.12 (338623054.84)"
## Delaware District of Columbia
## "M (var) = 7412.47 (25597794.30)" "M (var) = 6582.37 (18420990.17)"
## Florida Georgia
## "M (var) = 110458.66 (20549808845.26)" "M (var) = 55826.97 (3117505325.58)"
## Hawaii Idaho
## "M (var) = 754.18 (319784.83)" "M (var) = 5142.22 (37561901.47)"
## Illinois Indiana
## "M (var) = 91234.97 (3835241282.45)" "M (var) = 28995.28 (455362912.32)"
## Iowa Kansas
## "M (var) = 17777.07 (222814033.18)" "M (var) = 9651.47 (72514969.03)"
## Kentucky Louisiana
## "M (var) = 10301.81 (86002307.43)" "M (var) = 42288.85 (1099106001.76)"
## Maine Maryland
## "M (var) = 1950.57 (1789325.31)" "M (var) = 43083.18 (888300841.46)"
## Massachusetts Michigan
## "M (var) = 59075.84 (2332863636.98)" "M (var) = 51031.21 (778118807.94)"
## Minnesota Mississippi
## "M (var) = 20617.55 (349446084.07)" "M (var) = 18080.76 (321366130.14)"
## Missouri Montana
## "M (var) = 15424.34 (200879113.81)" "M (var) = 955.60 (1243847.65)"
## Nebraska Nevada
## "M (var) = 11180.60 (84232938.20)" "M (var) = 12921.20 (208118604.42)"
## New Hampshire New Jersey
## "M (var) = 3452.22 (5750607.37)" "M (var) = 118239.52 (4354790214.04)"
## New Mexico New York
## "M (var) = 7769.04 (42548261.44)" "M (var) = 280235.03 (22915287915.56)"
## North Carolina North Dakota
## "M (var) = 37117.99 (1560810914.28)" "M (var) = 2369.15 (4011375.43)"
## Ohio Oklahoma
## "M (var) = 34658.71 (786174403.82)" "M (var) = 9696.54 (115924395.88)"
## Oregon Pennsylvania
## "M (var) = 5452.09 (31697670.62)" "M (var) = 60222.86 (1496539414.07)"
## Rhode Island South Carolina
## "M (var) = 9471.77 (50947092.84)" "M (var) = 23304.10 (797856878.30)"
## South Dakota Tennessee
## "M (var) = 4188.93 (8936234.68)" "M (var) = 29886.66 (1001545684.73)"
## Texas Utah
## "M (var) = 97499.42 (17596205321.66)" "M (var) = 13108.60 (169495428.35)"
## Vermont Virginia
## "M (var) = 889.22 (181333.38)" "M (var) = 37562.43 (927534909.98)"
## West Virginia Wisconsin
## "M (var) = 2184.65 (3730471.36)" "M (var) = 15173.51 (274313734.01)"
## Wyoming
## "M (var) = 1000.89 (643525.29)"
#check if variance is greater than mean for deaths
with(data, tapply(deaths, Province.State, function(x) {
sprintf("M (var) = %1.2f (%1.2f)", mean(x, na.rm=TRUE), var(x, na.rm=TRUE))
}))## Alabama
## "M (var) = 2.25 (1.64)" "M (var) = 616.07 (248109.11)"
## Alaska Arizona
## "M (var) = 10.77 (38.12)" "M (var) = 877.88 (1195310.29)"
## Arkansas California
## "M (var) = 158.68 (20685.71)" "M (var) = 3057.39 (9334585.18)"
## Colorado Connecticut
## "M (var) = 1043.47 (475901.05)" "M (var) = 2737.37 (3103878.21)"
## Delaware District of Columbia
## "M (var) = 333.79 (53618.15)" "M (var) = 333.37 (52442.10)"
## Florida Georgia
## "M (var) = 2293.38 (4101248.45)" "M (var) = 1683.50 (1550561.07)"
## Hawaii Idaho
## "M (var) = 14.07 (67.71)" "M (var) = 70.68 (2619.70)"
## Illinois Indiana
## "M (var) = 4112.33 (8131003.40)" "M (var) = 1591.56 (1225143.99)"
## Iowa Kansas
## "M (var) = 408.90 (103383.62)" "M (var) = 178.66 (13483.91)"
## Kentucky Louisiana
## "M (var) = 383.06 (57880.22)" "M (var) = 2180.05 (1682368.90)"
## Maine Maryland
## "M (var) = 67.95 (1889.84)" "M (var) = 2029.09 (1641686.33)"
## Massachusetts Michigan
## "M (var) = 4011.56 (12895498.79)" "M (var) = 4223.20 (5733952.34)"
## Minnesota Mississippi
## "M (var) = 798.19 (411078.89)" "M (var) = 648.05 (294810.52)"
## Missouri Montana
## "M (var) = 610.76 (195501.74)" "M (var) = 19.09 (236.90)"
## Nebraska Nevada
## "M (var) = 158.45 (13421.83)" "M (var) = 345.92 (63064.48)"
## New Hampshire New Jersey
## "M (var) = 192.53 (25998.42)" "M (var) = 9599.58 (37868657.99)"
## New Mexico New York
## "M (var) = 293.80 (52011.31)" "M (var) = 21416.60 (156954791.19)"
## North Carolina North Dakota
## "M (var) = 769.47 (422457.19)" "M (var) = 48.43 (1333.15)"
## Ohio Oklahoma
## "M (var) = 1720.73 (1529599.19)" "M (var) = 266.57 (28796.52)"
## Oregon Pennsylvania
## "M (var) = 131.35 (9641.35)" "M (var) = 3957.30 (7948772.71)"
## Rhode Island South Carolina
## "M (var) = 452.47 (155225.49)" "M (var) = 535.97 (250786.18)"
## South Dakota Tennessee
## "M (var) = 53.88 (2039.21)" "M (var) = 381.30 (105856.62)"
## Texas Utah
## "M (var) = 1724.72 (3718210.82)" "M (var) = 108.93 (9195.08)"
## Vermont Virginia
## "M (var) = 42.56 (385.89)" "M (var) = 1055.32 (620667.41)"
## West Virginia Wisconsin
## "M (var) = 59.00 (1603.68)" "M (var) = 392.94 (118066.40)"
## Wyoming
## "M (var) = 12.02 (87.56)"
#check if variance is greater than mean for hospitalizations
with(data, tapply(hospitalizations, Province.State, function(x) {
sprintf("M (var) = %1.2f (%1.2f)", mean(x, na.rm=TRUE), var(x, na.rm=TRUE))
}))## Alabama
## "M (var) = NaN (NA)" "M (var) = 4220.78 (6309089.98)"
## Alaska Arizona
## "M (var) = 53.81 (1009.80)" "M (var) = 10060.00 (NA)"
## Arkansas California
## "M (var) = 1266.76 (562035.98)" "M (var) = NaN (NA)"
## Colorado Connecticut
## "M (var) = 3828.04 (4031948.94)" "M (var) = NaN (NA)"
## Delaware District of Columbia
## "M (var) = NaN (NA)" "M (var) = NaN (NA)"
## Florida Georgia
## "M (var) = 10328.16 (53739973.62)" "M (var) = 8243.03 (24344251.09)"
## Hawaii Idaho
## "M (var) = 89.88 (2282.84)" "M (var) = 306.66 (41890.98)"
## Illinois Indiana
## "M (var) = NaN (NA)" "M (var) = 7316.12 (872382.42)"
## Iowa Kansas
## "M (var) = NaN (NA)" "M (var) = 744.15 (201000.59)"
## Kentucky Louisiana
## "M (var) = 2946.13 (195215.09)" "M (var) = NaN (NA)"
## Maine Maryland
## "M (var) = 260.07 (10366.05)" "M (var) = 7658.91 (16012307.53)"
## Massachusetts Michigan
## "M (var) = 1687.30 (215771.79)" "M (var) = NaN (NA)"
## Minnesota Mississippi
## "M (var) = 2584.30 (2912611.07)" "M (var) = 1842.72 (1175549.93)"
## Missouri Montana
## "M (var) = NaN (NA)" "M (var) = 88.02 (2773.30)"
## Nebraska Nevada
## "M (var) = 1338.55 (52177.03)" "M (var) = NaN (NA)"
## New Hampshire New Jersey
## "M (var) = 448.72 (43888.19)" "M (var) = NaN (NA)"
## New Mexico New York
## "M (var) = 1496.10 (532600.29)" "M (var) = NaN (NA)"
## North Carolina North Dakota
## "M (var) = NaN (NA)" "M (var) = 148.54 (11686.26)"
## Ohio Oklahoma
## "M (var) = 5466.83 (9814958.57)" "M (var) = 1147.75 (665100.00)"
## Oregon Pennsylvania
## "M (var) = 738.48 (159758.80)" "M (var) = NaN (NA)"
## Rhode Island South Carolina
## "M (var) = NaN (NA)" "M (var) = NaN (NA)"
## South Dakota Tennessee
## "M (var) = 428.98 (73272.33)" "M (var) = 1971.84 (1622444.79)"
## Texas Utah
## "M (var) = NaN (NA)" "M (var) = 992.18 (493498.92)"
## Vermont Virginia
## "M (var) = NaN (NA)" "M (var) = 4143.43 (6448855.13)"
## West Virginia Wisconsin
## "M (var) = NaN (NA)" "M (var) = 2553.50 (1478238.41)"
## Wyoming
## "M (var) = 92.00 (1419.42)"
Based on the fact that variance >> mean for all of the variables:
Because of the skew in the distribution in the exposure (cases) and the fact that variance is greater than the mean, I will explore the use of a negative binomial regression to predict daily death count. This has been done in other studies also (https://www.nber.org/papers/w27391, https://jamanetwork.com/journals/jamanetworkopen/article-abstract/2779417) I take into account that there is an offset term given by the log of the population, and that daily deaths will vary by location_id/state.
I have plotted the actual vs predicted daily death count for 3 negative binomial regression models, with death as the outcome. In the first model, I use only cases as a predictor with an offset population term, while accounting for state level differences. The prediction varies by state.
In the second model, I use cases and hospitalizations as the predictor, while accounting for state level differences. The third model only uses cases as the predictor with an offset population term.
library(pscl)
library(lme4)
#negative binomial regression
#simple model, only cases as the predictor
model1 = MASS::glm.nb(formula = deaths ~ cases + factor(location_id) + offset(log(population)),
data = data,
na.action = na.exclude)
#includes hospitalizations as an additional predictor
model2 = MASS:: glm.nb(formula = deaths ~ cases + hospitalizations + factor(location_id) + offset(log(population)),
data = data,
na.action = na.exclude)
#most parsimonious model, does not account for state
model3 = MASS::glm.nb(formula = deaths ~ cases + offset(log(population)),
data = data,
na.action = na.exclude)
data$predicted_cases1 <- predict(model1, type = "response")
data$predicted_cases2 <- predict(model2, type = "response")
data$predicted_cases3 <- predict(model3, type = "response")
#actual no of cases (x) vs predicted (y)
ggplot(data, aes(x = cases, y = predicted_cases1, color=as.factor(Province.State))) +
scale_color_viridis(discrete=TRUE, option = "D")+
geom_point(alpha = 0.5) +
labs(title = "Model 1: deaths ~ cases",
x = "Actual Cases",
y = "Predicted Cases") +
theme_minimal() +
theme(plot.title = element_text(color="black", size=10, face="bold.italic"))ggplot(data, aes(x = cases, y = predicted_cases2, color=as.factor(Province.State))) +
scale_color_viridis(discrete=TRUE, option = "D")+
geom_point(alpha = 0.5) +
labs(title = "Model 2: deaths ~ cases + hospitalizations",
x = "Actual Cases",
y = "Predicted Cases") +
theme_minimal() +
theme(plot.title = element_text(color="black", size=10, face="bold.italic"))ggplot(data, aes(x = cases, y = predicted_cases3, color=as.factor(Province.State))) +
scale_color_viridis(discrete=TRUE, option = "D")+
geom_point(alpha = 0.5) +
labs(title = "Model 3: deaths ~ cases (no state)",
x = "Actual Cases",
y = "Predicted Cases") +
theme_minimal() +
theme(plot.title = element_text(color="black", size=10, face="bold.italic"))I decided to use ARIMA to perform time-series forecasting due to several reasons. ARIMA is AutoRegressive Integrated Moving Average, specified by an autoregressive component that refers to the use of past values, difference of observations to make the time series stationary, and a moving average component that shows the error of the model as a combination of prior error term.
First, ARIMA is better suited to handle nonstationarity compared to other linear regression methods Second, ARIMA is easier to use than classical nonlinear epidemiological models such as SIR or SEIR for which no closed form exact solutions exists, and numerical simulations are required [Barlow NS, Weinstein SJ. Accurate closed-form solution of the SIR epidemic model. Physica D. 2020;408:132540. doi: 10.1016/j.physd.2020.132540.]. A recent study also showed that ARIMA models outperformed SIR in predicting COVID-19 cases [Abuhasel KA, Khadr M, Alquraish MM. Analyzing and forecasting COVID-19 pandemic in the Kingdom of Saudi Arabia using ARIMA and SIR models. Comput Intell. 2022;38:770–783. doi: 10.1111/coin.12407, Abolmaali S, Shirzaei S. A comparative study of SIR Model, Linear Regression, Logistic Function and ARIMA Model for forecasting COVID-19 cases. AIMS Public Health. 2021;8:598–613. doi: 10.3934/publichealth.2021048.].
Benefits: If the time series data is stationary, the ARIMA model works well and produces accurate forecasts.
Limitations: ARIMA also doesn’t work if the residuals are not normally distributed, and if the residuals are correlated. ARIMA also assumes constant mean and variance over time. ARIMA is very sensitive to missing data, as we have seen. ARIMA does not work on missing data.
## Warning: package 'forecast' was built under R version 4.3.3
## Registered S3 method overwritten by 'quantmod':
## method from
## as.zoo.data.frame zoo
library(dplyr)
#Create a state-specific 14 day forecast for cases, deaths, hospitalization
states = c("Alabama", "Alaska", "Arizona", "Arkansas", "California", "Colorado",
"Connecticut", "Delaware", "Florida", "Georgia", "Hawaii", "Idaho",
"Illinois", "Indiana", "Iowa", "Kansas", "Kentucky", "Louisiana",
"Maine", "Maryland", "Massachusetts", "Michigan", "Minnesota",
"Mississippi", "Missouri", "Montana", "Nebraska", "Nevada",
"New Hampshire", "New Jersey", "New Mexico", "New York", "North Carolina",
"North Dakota", "Ohio", "Oklahoma", "Oregon", "Pennsylvania",
"Rhode Island", "South Carolina", "South Dakota", "Tennessee", "Texas",
"Utah", "Vermont", "Virginia", "West Virginia",
"Wisconsin", "Wyoming", "District of Columbia")
##CASES
forecast_by_state_cases <- function(data) {
for (state in states) {
# Filter data for the current state
data_subset <- data %>%
filter(Province.State == state) %>%
select(cases)
# Fit ARIMA model and forecast
fit <- auto.arima(data_subset$cases)
forecastedValues <- forecast(fit, 14)
# Plot the forecast
plot(forecastedValues, main = paste("Forecast for", state),
col.main = "darkgreen",
xlab = "Time in days",
ylab = "Cases")
}
}
forecast_by_state_cases(data)##DEATHS
forecast_by_state_deaths <- function(data) {
for (state in states) {
# Filter data for the current state
data_subset <- data %>%
filter(Province.State == state) %>%
select(deaths)
# Fit ARIMA model and forecast
fit <- auto.arima(data_subset$deaths)
forecastedValues <- forecast(fit, 14)
# Plot the forecast
plot(forecastedValues, main = paste("Forecast for", state),
col.main = "purple",
xlab = "Time in days",
ylab = "Deaths")
}
}
forecast_by_state_deaths(data)##HOSPITALIZATIONS
#subsetting data to contain states containing no missing hospitalization data
hosp_data = data %>%
filter(complete.cases(hospitalizations))
state_hosp_name = c("Alabama", "Alaska","Arizona","Arkansas","Colorado", "Florida", "Georgia","Hawaii","Idaho", "Indiana","Kansas","Kentucky","Maine","Maryland", "Massachusetts", "Minnesota","Mississippi","Montana","Nebraska", "New Hampshire", "New Mexico" ,"North Dakota","Ohio", "Oklahoma","Oregon", "South Dakota","Tennessee","Utah","Virginia", "Wisconsin", "Wyoming")
forecast_by_state_hosps <- function(data) {
for (state in state_hosp_name) {
# Filter data for the current state
data_subset <- data %>%
filter(Province.State == state) %>%
select(hospitalizations)
# Fit ARIMA model and forecast
fit <- auto.arima(data_subset$hospitalizations)
forecastedValues <- forecast(fit, 14)
# Plot the forecast
plot(forecastedValues, main = paste("Forecast for", state),
col.main = "darkblue",
xlab = "Time in days",
ylab = "Hospitalizations")
}
}
forecast_by_state_hosps(hosp_data)One assumption made is that historical data is used to make predictions, which may not be entirely true for epidemic modelling, as other factors such as new variants, changing climates, vaccination trends, and others can all affect trajectories. Inter-state dynamics or relationships between the different states could also impact these trajectories. If I had more time, I would look into other confounding factors and include them in a model to a make more accurate predictions.
Assuming we have access to other data sources, I would like to use recurrent neural networks to do forecasting. Currently we are doing univariate forecasting, neural networks (such as Transformers (https://arxiv.org/abs/1706.03762)) will allow multivariate forecasting. Another benefit of Transformers over the other architectures is that we can incorporate missing values (which are common in the time series setting). However, one drawback is that training these models require access to large datasets and are computationally expensive.